"""
Implements the PoseCNN network architecture in PyTorch.
"""
import math
import os,sys
import numpy as np
from PIL import Image
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
from torchvision.ops import RoIPool
from torchvision.utils import save_image, make_grid

from InitPose.lib.utils import quaternion_to_matrix
from InitPose.lib.loss.loss import HoughVoting, _LABEL2MASK_THRESHOL, \
    loss_cross_entropy, loss_Rotation, IOUselection,loss_rotation

from InitPose.models.gfnet  import Block
# from visualizer.util_visualizer import tensor2im,save_image

class FeatureExtraction(nn.Module):
    """
    Feature embedding backbone for PoseCNN, built from a pretrained VGG16.

    The first 30 layers of ``pretrained_model.features`` are split into two
    sequential stages; both intermediate feature maps are returned.
    """

    # Indices (within the first stage) of the conv layers that are frozen.
    _FROZEN_CONV_IDX = (0, 2, 5, 7, 10, 12, 14)

    def __init__(self, pretrained_model):
        super(FeatureExtraction, self).__init__()
        backbone_layers = list(pretrained_model.features)[:30]
        # Stage 1: input image -> first feature map.
        self.embedding1 = nn.Sequential(*backbone_layers[:23])
        # Stage 2: first feature map -> second (coarser) feature map.
        self.embedding2 = nn.Sequential(*backbone_layers[23:])

        # Freeze the earliest backbone conv layers (keep pretrained weights).
        for idx in self._FROZEN_CONV_IDX:
            frozen = self.embedding1[idx]
            frozen.weight.requires_grad = False
            frozen.bias.requires_grad = False

    def forward(self, datadict):
        """
        Args:
            datadict: dict with key 'rgb' holding the input image batch.

        Returns:
            feature1: [bs, 512, H/8, W/8]
            feature2: [bs, 512, H/16, W/16]
        """
        feature1 = self.embedding1(datadict['rgb'])
        feature2 = self.embedding2(feature1)
        return feature1, feature2

class SegmentationBranch(nn.Module):
    """
    Instance Segmentation Module for PoseCNN.

    Predicts a per-pixel class-probability map from the two backbone
    feature maps and derives a hard segmentation plus per-class
    bounding boxes from it.
    """

    def __init__(self, num_classes = 1, hidden_layer_dim = 64):
        super(SegmentationBranch, self).__init__()

        # +1 for the background class (id 0).
        self.num_classes = num_classes + 1

        # Shared 1x1 reduction applied to both backbone feature maps.
        self.seg_branch = nn.Sequential(
            nn.Conv2d(512, hidden_layer_dim, 1, bias=True),
            nn.ReLU(inplace=True),
        )

        # Final 1x1 classifier producing per-class logits.
        # Fixed: in-channels previously hard-coded to 64, which broke any
        # non-default hidden_layer_dim.
        self.conv_class = nn.Conv2d(hidden_layer_dim, num_classes + 1, 1, bias=True)

        # Kaiming initialization for conv weights, zero biases.
        def init_weights(m):
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.bias.data.fill_(0.0)

        self.seg_branch.apply(init_weights)
        init_weights(self.conv_class)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, feature1, feature2):
        """
        Args:
            feature1: Features from feature extraction backbone (B, 512, h, w)
            feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
        Returns:
            probability: Segmentation map of probability for each class at each pixel.
                probability size: (B,num_classes+1,H,W)
            segmentation: Segmentation map of class id's with highest prob at each pixel.
                segmentation size: (B,H,W)
            bbx: Bounding boxs detected from the segmentation. Can be extracted
                from the predicted segmentation map using self.label2bbx(segmentation).
                bbx size: (N,6) with (batch_ids, x1, y1, x2, y2, cls)
        """
        # Reduce both feature maps to the hidden dimension.
        f1 = self.seg_branch(feature1)
        f2 = self.seg_branch(feature2)

        # Fuse: upsample the coarser map x2 (bilinear) and add.
        # Fixed: use functional interpolation instead of constructing a new
        # nn.Upsample module on every forward call.
        fused = f1 + nn.functional.interpolate(f2, scale_factor=2, mode="bilinear")

        # Upsample x8 (nearest, as before) back to input resolution, classify.
        logits = self.conv_class(nn.functional.interpolate(fused, scale_factor=8))

        probability = nn.functional.softmax(logits, dim=1)
        segmentation = torch.max(probability, dim=1)[1]
        # Fixed: removed debug save_image(make_grid(probability), 'out/output.png')
        # which wrote a PNG on every forward pass (and crashed when the
        # 'out/' directory was missing).
        bbx = self.label2bbx(segmentation)

        return probability, segmentation, bbx

    def label2bbx(self, label):
        """
        Extract one bounding box per (batch, class) from a segmentation map.

        Args:
            label: (B, H, W) tensor of per-pixel class ids.
        Returns:
            (N, 6) tensor of [batch_id, x1, y1, x2, y2, cls]; a box is kept
            only when the class covers at least _LABEL2MASK_THRESHOL pixels.
        """
        bbx = []
        bs, H, W = label.shape
        device = label.device
        # One boolean mask per (batch, class): pixel belongs to that class.
        label_repeat = label.view(bs, 1, H, W).repeat(1, self.num_classes, 1, 1).to(device)
        label_target = torch.linspace(0, self.num_classes - 1, steps = self.num_classes).view(1, -1, 1, 1).repeat(bs, 1, H, W).to(device)
        mask = (label_repeat == label_target)
        for batch_id in range(mask.shape[0]):
            for cls_id in range(mask.shape[1]):
                if cls_id != 0:
                    # cls_id == 0 is the background
                    y, x = torch.where(mask[batch_id, cls_id] != 0)
                    if y.numel() >= _LABEL2MASK_THRESHOL:
                        bbx.append([batch_id, torch.min(x).item(), torch.min(y).item(),
                                    torch.max(x).item(), torch.max(y).item(), cls_id])
        return torch.tensor(bbx).to(device)
        
        
class TranslationBranch(nn.Module):
    """
    3D Translation Estimation Module for PoseCNN.

    Regresses, for every class, a 3-channel per-pixel map used to vote
    for the object centroid.
    """

    def __init__(self, num_classes = 1, hidden_layer_dim = 128):
        super(TranslationBranch, self).__init__()

        # Shared 1x1 reduction applied to both backbone feature maps.
        self.trans_branch = nn.Sequential(
            nn.Conv2d(512, hidden_layer_dim, 1, bias=True),
            nn.ReLU(inplace=True),
        )

        # Final 1x1 regressor: 3 channels per class.
        # Fixed: in-channels previously hard-coded to 128, which broke any
        # non-default hidden_layer_dim.
        self.conv_class = nn.Conv2d(hidden_layer_dim, 3 * num_classes, 1, bias=True)

        # Kaiming initialization for conv weights, zero biases.
        def init_weights(m):
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
                m.bias.data.fill_(0.0)

        self.trans_branch.apply(init_weights)
        init_weights(self.conv_class)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, feature1, feature2):
        """
        Args:
            feature1: Features from feature extraction backbone (B, 512, h, w)
            feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
        Returns:
            translation: Map of object centroid predictions.
                translation size: (N,3*num_classes,H,W)
        """
        f1 = self.trans_branch(feature1)
        f2 = self.trans_branch(feature2)

        # Fuse: upsample the coarser map x2 (bilinear) and add.
        # Fixed: use functional interpolation instead of constructing a new
        # nn.Upsample module on every forward call.
        fused = f1 + nn.functional.interpolate(f2, scale_factor=2, mode="bilinear")

        # Upsample x8 (nearest, as before) back to input resolution, regress.
        return self.conv_class(nn.functional.interpolate(fused, scale_factor=8))

class RotationBranch(nn.Module):
    """
    3D Rotation Regression Module for PoseCNN.

    Pools ROI features from both backbone feature maps, optionally passes
    them through a GFNet-style FFT block, and regresses one quaternion per
    class for each ROI with a 3-layer MLP.
    """

    def __init__(self, using_fft=True, feature_dim = 512, roi_shape = 7, hidden_dim = 4096, num_classes = 1):
        super(RotationBranch, self).__init__()
        self.using_fft = using_fft
        # GFNet block over (B, N, C) token sequences; w is the rfft width
        # for an roi_shape x roi_shape spatial grid.
        self.fft = Block(dim=feature_dim, h=roi_shape, w=int(roi_shape / 2 + 0.5))

        # ROI pooling at the two backbone strides (1/8 and 1/16).
        self.ROI_1 = RoIPool((roi_shape, roi_shape), 1 / 8)
        self.ROI_2 = RoIPool((roi_shape, roi_shape), 1 / 16)

        # Quaternion regression head.
        # Fixed: the hidden width was hard-coded to 4096 even though
        # hidden_dim existed as a parameter (default unchanged).
        self.quat = nn.Sequential(
            nn.Linear(in_features=feature_dim * roi_shape * roi_shape, out_features=hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden_dim, out_features=4 * num_classes),
        )

        # Zero the MLP biases; weights keep their default initialization.
        def init_weights(m):
            if isinstance(m, nn.Linear):
                m.bias.data.fill_(0.0)

        self.quat.apply(init_weights)

    def forward(self, feature1, feature2, bbx):
        """
        Args:
            feature1: Features from feature extraction backbone (B, 512, h, w)
            feature2: Features from feature extraction backbone (B, 512, h//2, w//2)
            bbx: Bounding boxes of regions of interst (N, 5) with (batch_ids, x1, y1, x2, y2)
        Returns:
            quaternion: Regressed components of a quaternion for each class at each ROI.
                quaternion size: (N,4*num_classes)
        """
        # RoIPool expects float boxes matching the feature dtype.
        rois = bbx.to(feature1.dtype)

        ft_1_roi = self.ROI_1(feature1, rois)
        if self.using_fft:
            ft_1_roi = self.transreview(self.fft(self.review(ft_1_roi)))

        ft_2_roi = self.ROI_2(feature2, rois)
        if self.using_fft:
            ft_2_roi = self.transreview(self.fft(self.review(ft_2_roi)))

        # Element-wise fusion of the two pooled maps, then the MLP head.
        ft_cat = ft_1_roi + ft_2_roi
        return self.quat(ft_cat.flatten(1))

    def review(self, feature):
        """(B, C, H, W) -> (B, H*W, C) token layout for the FFT block."""
        # NOTE(review): .view() reinterprets memory rather than transposing
        # axes, so channel and spatial values are interleaved. If gfnet.Block
        # expects channel-last tokens, permute(0, 2, 3, 1).reshape(...) may be
        # intended — confirm against Block before changing, since any trained
        # weights depend on the current layout. Left as-is deliberately.
        b, c, h, w = feature.shape
        feature = feature.view(b, h * w, c)
        return feature

    def transreview(self, feature):
        """Inverse of review(): (B, N, C) -> (B, C, sqrt(N), sqrt(N)).

        Assumes N is a perfect square (roi_shape x roi_shape).
        """
        b, n, c = feature.shape
        feature = feature.view(b, c, int(math.sqrt(n)), int(math.sqrt(n)))
        return feature
        
class PoseCNN(nn.Module):
    """
    PoseCNN: joint instance segmentation, 3D translation and 3D rotation
    estimation from a single RGB image.
    """

    def __init__(self, pretrained_backbone, models_pcd, cam_intrinsic, using_fft, using_loss_model):
        """
        Args:
            pretrained_backbone: pretrained VGG16-style network whose
                ``.features`` feed FeatureExtraction.
            models_pcd: per-class model point clouds used by loss_Rotation.
            cam_intrinsic: 3x3 camera intrinsic matrix (numpy).
            using_fft: forward ROI features through the GFNet FFT block.
            using_loss_model: use point-cloud rotation loss (loss_Rotation)
                instead of the plain rotation loss (loss_rotation).
        """
        super(PoseCNN, self).__init__()
        self.using_fft = using_fft
        self.using_loss_model = using_loss_model

        # Minimum IoU against a GT box for a predicted box to be used in
        # the rotation loss.
        self.iou_threshold = 0.7
        self.models_pcd = models_pcd
        self.cam_intrinsic = cam_intrinsic

        self.FeatureExtraction = FeatureExtraction(pretrained_backbone)
        self.SegmentationBranch = SegmentationBranch()
        self.TranslationBranch = TranslationBranch()
        self.RotationBranch = RotationBranch(using_fft=self.using_fft)

    def forward(self, input_dict):
        """
        input_dict = {
            'rgb',
            'depth',
            'objs_id',
            'mask',
            'bbx',
            'RTs'
        }
        (training additionally reads 'label' and 'centermaps')

        Returns:
            training: dict of losses {loss_segmentation, loss_centermap, loss_R}
            eval: (output_dict of 4x4 poses per batch/object, segmentation map)
        """
        if self.training:
            loss_dict = {
                "loss_segmentation": 0,
                "loss_centermap": 0,
                "loss_R": 0
            }

            gt_bbx = self.getGTbbx(input_dict)

            feature1, feature2 = self.FeatureExtraction(input_dict)
            probs, seg, bbx = self.SegmentationBranch(feature1, feature2)
            # (B, 3 * num_classes, H, W) centroid-voting map.
            translation = self.TranslationBranch(feature1, feature2)

            loss_dict["loss_segmentation"] = loss_cross_entropy(probs, input_dict["label"])

            # L1 loss between predicted and ground-truth center maps.
            center_loss = torch.nn.L1Loss()
            loss_dict["loss_centermap"] = center_loss(translation, input_dict["centermaps"])

            # Keep only predicted boxes with sufficient IoU against GT boxes.
            filtered_bbx = IOUselection(bbx, gt_bbx, self.iou_threshold)

            # Fixed: was `if len(filtered_bbx != 0):` (comparison applied to
            # the tensor instead of to its length).
            if len(filtered_bbx) != 0:
                # Fixed: regress quaternions for the *filtered* boxes so row
                # idx of quaternion_map lines up with filtered_bbx rows in
                # estimateRotation. Previously all boxes (bbx) were passed,
                # misaligning quaternions with GT rotations and labels.
                quaternion_map = self.RotationBranch(feature1, feature2, filtered_bbx[:, :5])
                gt_rot = self.gtRotation(filtered_bbx, input_dict)
                pred_rot, labels = self.estimateRotation(quaternion_map, filtered_bbx)
                if self.using_loss_model:
                    rot_loss = loss_Rotation(pred_rot, gt_rot, labels, self.models_pcd)
                else:
                    rot_loss = loss_rotation(pred_rot, gt_rot)
                loss_dict["loss_R"] = rot_loss

            return loss_dict
        else:
            output_dict = None
            segmentation = None

            with torch.no_grad():
                feature1, feature2 = self.FeatureExtraction(input_dict)
                probs, segmentation, bbx = self.SegmentationBranch(feature1, feature2)
                translation_map = self.TranslationBranch(feature1, feature2)

                # Vote object centers and depths from the translation map.
                pred_centers, pred_depths = HoughVoting(segmentation, translation_map)

                quaternion_map = self.RotationBranch(feature1, feature2, bbx[:, :5])
                pred_Rs, lbl = self.estimateRotation(quaternion_map, bbx)

                output_dict = self.generate_pose(pred_Rs, pred_centers, pred_depths, bbx)

            return output_dict, segmentation

    def estimateTrans(self, translation_map, filter_bbx, pred_label):
        """
        Average the translation map over each object's predicted mask.

        translation_map: a tensor [batch_size, num_classes * 3, height, width]
        filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
        pred_label: a tensor [batch_size, height, width] of class ids
        Returns: (N_filter_bbx, 3) tensor of translations.
        """
        N_filter_bbx = filter_bbx.shape[0]
        pred_Ts = torch.zeros(N_filter_bbx, 3)
        for idx, bbx in enumerate(filter_bbx):
            batch_id = int(bbx[0].item())
            cls = int(bbx[5].item())
            # 3 channels belonging to this class (classes are 1-based).
            trans_map = translation_map[batch_id, (cls - 1) * 3: cls * 3, :]
            label = (pred_label[batch_id] == cls).detach()
            pred_T = trans_map[:, label].mean(dim=1)
            pred_Ts[idx] = pred_T
        return pred_Ts

    def gtTrans(self, filter_bbx, input_dict):
        """Collect the GT translation (last column of RT) for each box."""
        N_filter_bbx = filter_bbx.shape[0]
        gt_Ts = torch.zeros(N_filter_bbx, 3)
        for idx, bbx in enumerate(filter_bbx):
            batch_id = int(bbx[0].item())
            cls = int(bbx[5].item())
            gt_Ts[idx] = input_dict['RTs'][batch_id][cls - 1][:3, [3]].T
        return gt_Ts

    def getGTbbx(self, input_dict):
        """
        Build GT boxes from input_dict['bbx'] for objects present in the image.

        Returns: N*6 int16 tensor (batch_ids, x1, y1, x2, y2, cls).
        """
        gt_bbx = []
        objs_id = input_dict['objs_id']
        device = objs_id.device
        ## [x_min, y_min, width, height]
        bbxes = input_dict['bbx']
        for batch_id in range(bbxes.shape[0]):
            for idx, obj_id in enumerate(objs_id[batch_id]):
                if obj_id.item() != 0:
                    # the obj appears in this image
                    bbx = bbxes[batch_id][idx]
                    gt_bbx.append([batch_id, bbx[0].item(), bbx[1].item(),
                                  bbx[2].item(), bbx[3].item(), obj_id.item()])
        return torch.tensor(gt_bbx).to(device=device, dtype=torch.int16)

    def estimateRotation(self, quaternion_map, filter_bbx):
        """
        Convert the regressed quaternions into rotation matrices.

        quaternion_map: a tensor [N_filter_bbx, num_classes * 4]; row idx must
            correspond to row idx of filter_bbx.
        filter_bbx: N_filter_bbx * 6 (batch_ids, x1, y1, x2, y2, cls)
        Returns: ((N_filter_bbx, 3, 3) rotations, (N_filter_bbx,) class labels).
        """
        N_filter_bbx = filter_bbx.shape[0]
        pred_Rs = torch.zeros(N_filter_bbx, 3, 3)
        label = []
        for idx, bbx in enumerate(filter_bbx):
            batch_id = int(bbx[0].item())
            cls = int(bbx[5].item())
            # Take this class's 4 quaternion components and normalize.
            quaternion = quaternion_map[idx, (cls - 1) * 4: cls * 4]
            quaternion = nn.functional.normalize(quaternion, dim=0)
            pred_Rs[idx] = quaternion_to_matrix(quaternion)
            label.append(cls)
        label = torch.tensor(label)
        return pred_Rs, label

    def gtRotation(self, filter_bbx, input_dict):
        """Collect the GT rotation (top-left 3x3 of RT) for each box."""
        N_filter_bbx = filter_bbx.shape[0]
        gt_Rs = torch.zeros(N_filter_bbx, 3, 3)
        for idx, bbx in enumerate(filter_bbx):
            batch_id = int(bbx[0].item())
            cls = int(bbx[5].item())
            gt_Rs[idx] = input_dict['RTs'][batch_id][cls - 1][:3, :3]
        return gt_Rs

    def generate_pose(self, pred_Rs, pred_centers, pred_depths, bbxs):
        """
        Assemble 4x4 pose matrices from rotations, voted centers and depths.

        pred_Rs: a tensor [pred_bbx_size, 3, 3]
        pred_centers: [batch_size, num_classes, 2] pixel centers
        pred_depths: a tensor [batch_size, num_classes]
        bbxs: a tensor [pred_bbx_size, 6]
        Returns: {batch_id: {obj_id: 4x4 numpy pose}}; objects whose voted
        center is (0, 0) (no votes) are skipped.
        """
        output_dict = {}
        for idx, bbx in enumerate(bbxs):
            bs, _, _, _, _, obj_id = bbx
            R = pred_Rs[idx].numpy()
            center = pred_centers[bs, obj_id - 1].numpy()
            depth = pred_depths[bs, obj_id - 1].numpy()
            if (center ** 2).sum().item() != 0:
                # Back-project the pixel center to a 3D translation.
                T = np.linalg.inv(self.cam_intrinsic) @ np.array([center[0], center[1], 1]) * depth
                T = T[:, np.newaxis]
                if bs.item() not in output_dict:
                    output_dict[bs.item()] = {}
                output_dict[bs.item()][obj_id.item()] = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]])))
        return output_dict


if __name__ == "__main__":
    # Smoke test: run RotationBranch on random features and random boxes.
    boxes = torch.rand(2, 5)
    feat_stride8 = torch.randn(2, 512, 60, 80)
    feat_stride16 = torch.randn(2, 512, 30, 40)
    branch = RotationBranch()
    quats = branch(feat_stride8, feat_stride16, boxes)
    print(quats.shape)